/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
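/*! \file
    \brief Device-level universal GEMM operator (cutlass::gemm::device::GemmUniversal)
      and its partial specialization for column-major output, which is computed by
      solving the transposed problem with the row-major variant.
*/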

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"

#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*!
  The universal GEMM accommodates serial reductions, parallel reductions,
  batched strided, and batched array variants.

  This primary template is used for every output layout other than
  layout::ColumnMajor, which has its own partial specialization. It is a thin
  wrapper deriving from GemmUniversalBase, instantiated on the kernel type
  selected by kernel::DefaultGemmUniversal. The base provides `Arguments`,
  `GemmKernel` and the host-side entry points (can_implement,
  get_workspace_size, get_grid_shape, maximum_active_blocks, initialize,
  update, run, ...).

  Most defaulted template parameters (tile shapes, epilogue, stages,
  alignments, math operator) come from DefaultGemmConfiguration for the chosen
  operator class, architecture and element types.
*/
template <
        /// Element type for A matrix operand
        typename ElementA_,
        /// Layout type for A matrix operand
        typename LayoutA_,
        /// Element type for B matrix operand
        typename ElementB_,
        /// Layout type for B matrix operand
        typename LayoutB_,
        /// Element type for C and D matrix operands
        typename ElementC_,
        /// Layout type for C and D matrix operands
        typename LayoutC_,
        /// Element type for internal accumulation
        typename ElementAccumulator_ = ElementC_,
        /// Operator class tag
        typename OperatorClass_ = arch::OpClassSimt,
        /// Tag indicating architecture to tune for
        typename ArchTag_ = arch::Sm70,
        /// Threadblock-level tile size (concept: GemmShape)
        typename ThreadblockShape_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::ThreadblockShape,
        /// Warp-level tile size (concept: GemmShape)
        typename WarpShape_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::WarpShape,
        /// Instruction-level tile size (concept: GemmShape)
        typename InstructionShape_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::InstructionShape,
        /// Epilogue output operator
        typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::EpilogueOutputOp,
        /// Threadblock-level swizzling operator
        typename ThreadblockSwizzle_ =
                threadblock::GemmIdentityThreadblockSwizzle<>,
        /// Number of stages used in the pipelined mainloop
        int Stages = DefaultGemmConfiguration<OperatorClass_, ArchTag_,
                                              ElementA_, ElementB_, ElementC_,
                                              ElementAccumulator_>::kStages,
        /// Access granularity of A matrix in units of elements
        int AlignmentA = DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::kAlignmentA,
        /// Access granularity of B matrix in units of elements
        int AlignmentB = DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::kAlignmentB,
        /// Operation performed by GEMM
        typename Operator_ = typename DefaultGemmConfiguration<
                OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
                ElementAccumulator_>::Operator,
        /// Complex elementwise transformation on A operand
        ComplexTransform TransformA = ComplexTransform::kNone,
        /// Complex elementwise transformation on B operand
        ComplexTransform TransformB = ComplexTransform::kNone>
class GemmUniversal
        // The base must be inherited publicly. A `class` defaults to private
        // inheritance, which would hide the inherited host interface
        // (can_implement, initialize, run, ...) from users of GemmUniversal.
        : public GemmUniversalBase<typename kernel::DefaultGemmUniversal<
                  ElementA_, LayoutA_, TransformA, AlignmentA, ElementB_,
                  LayoutB_, TransformB, AlignmentB, ElementC_, LayoutC_,
                  ElementAccumulator_, OperatorClass_, ArchTag_,
                  ThreadblockShape_, WarpShape_, InstructionShape_,
                  EpilogueOutputOp_, ThreadblockSwizzle_, Stages,
                  Operator_>::GemmKernel> {
public:
    using ElementAccumulator = ElementAccumulator_;
    using OperatorClass = OperatorClass_;
    using ArchTag = ArchTag_;
    using ThreadblockShape = ThreadblockShape_;
    using WarpShape = WarpShape_;
    using InstructionShape = InstructionShape_;
    using EpilogueOutputOp = EpilogueOutputOp_;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    using Operator = Operator_;
    static int const kStages = Stages;
    static int const kAlignmentA = AlignmentA;
    static int const kAlignmentB = AlignmentB;
    /// Access granularity of C/D, taken from EpilogueOutputOp::kCount.
    static int const kAlignmentC = EpilogueOutputOp::kCount;
    static ComplexTransform const kTransformA = TransformA;
    static ComplexTransform const kTransformB = TransformB;

    /// Device-level base operator. This must name exactly the same type as the
    /// base-specifier above. The column-major specialization forwards to it.
    using Base = GemmUniversalBase<typename kernel::DefaultGemmUniversal<
            ElementA_, LayoutA_, TransformA, AlignmentA, ElementB_, LayoutB_,
            TransformB, AlignmentB, ElementC_, LayoutC_, ElementAccumulator_,
            OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_,
            InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages,
            Operator_>::GemmKernel>;

    /// Argument structure, shared with the base operator.
    using Arguments = typename Base::Arguments;
    /// Kernel type launched by the base operator.
    using GemmKernel = typename Base::GemmKernel;
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output.
///
/// A column-major product D = A * B is computed as the row-major product
/// D^T = B^T * A^T. The A and B operands swap roles (element type, transposed
/// layout, alignment and complex transform), and every call is forwarded to the
/// row-major GemmUniversal's `Base` operator. Before each forwarded call, the
/// arguments are converted with Arguments::transposed_problem().
template <typename ElementA_,            ///< Element type for A matrix operand
          typename LayoutA_,             ///< Layout type for A matrix operand
          typename ElementB_,            ///< Element type for B matrix operand
          typename LayoutB_,             ///< Layout type for B matrix operand
          typename ElementC_,            ///< Element type for C and D operands
          typename ElementAccumulator_,  ///< Element type for accumulation
          typename OperatorClass_,       ///< Operator class tag
          typename ArchTag_,             ///< Architecture tag to tune for
          typename ThreadblockShape_,    ///< Threadblock tile (GemmShape)
          typename WarpShape_,           ///< Warp tile (GemmShape)
          typename InstructionShape_,    ///< Instruction tile (GemmShape)
          typename EpilogueOutputOp_,    ///< Epilogue output operator
          typename ThreadblockSwizzle_,  ///< Threadblock swizzling operator
          int Stages,                    ///< Pipelined mainloop stages
          int AlignmentA,                ///< Access granularity of A (elements)
          int AlignmentB,                ///< Access granularity of B (elements)
          typename Operator_,            ///< Operation performed by GEMM
          ComplexTransform TransformA,   ///< Complex transform on A
          ComplexTransform TransformB>   ///< Complex transform on B
class GemmUniversal<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                    layout::ColumnMajor,  // LayoutC fixed to column-major
                    ElementAccumulator_, OperatorClass_, ArchTag_,
                    ThreadblockShape_, WarpShape_, InstructionShape_,
                    EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA,
                    AlignmentB, Operator_, TransformA, TransformB> {
public:
    //
    // Element types
    //
    using ElementA = ElementA_;
    using ElementB = ElementB_;
    using ElementC = ElementC_;
    using ElementAccumulator = ElementAccumulator_;

    //
    // Layouts and tensor references
    //
    using LayoutA = LayoutA_;
    using LayoutB = LayoutB_;
    using LayoutC = layout::ColumnMajor;

    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;

    //
    // Configuration
    //
    using OperatorClass = OperatorClass_;
    using ArchTag = ArchTag_;
    using ThreadblockShape = ThreadblockShape_;
    using WarpShape = WarpShape_;
    using InstructionShape = InstructionShape_;
    using EpilogueOutputOp = EpilogueOutputOp_;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    using Operator = Operator_;

    static int const kStages = Stages;
    static int const kAlignmentA = AlignmentA;
    static int const kAlignmentB = AlignmentB;
    static int const kAlignmentC = EpilogueOutputOp::kCount;
    static ComplexTransform const kTransformA = TransformA;
    static ComplexTransform const kTransformB = TransformB;

    /// Row-major operator that solves the transposed problem. B takes the
    /// place of A and vice versa: element type, transposed layout, alignment
    /// and complex transform are all exchanged.
    using UnderlyingOperator = typename GemmUniversal<
            ElementB, typename layout::LayoutTranspose<LayoutB>::type,
            ElementA, typename layout::LayoutTranspose<LayoutA>::type,
            ElementC, layout::RowMajor,
            ElementAccumulator, OperatorClass, ArchTag,
            ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
            ThreadblockSwizzle, Stages,
            kAlignmentB, kAlignmentA,
            Operator,
            kTransformB, kTransformA>::Base;

    /// Kernel type launched by the underlying row-major operator.
    using GemmKernel = typename UnderlyingOperator::GemmKernel;

    /// Argument structure, shared with the underlying operator.
    using Arguments = typename UnderlyingOperator::Arguments;

    /// Constructs the GEMM.
    GemmUniversal() {}

    /// Returns `args` rewritten for the transposed problem that the underlying
    /// operator solves.
    static Arguments to_underlying_arguments(Arguments const& args) {
        return args.transposed_problem();
    }

    /// Returns whether the underlying operator can execute the transposed
    /// problem.
    static Status can_implement(Arguments const& args) {
        Arguments const transposed = to_underlying_arguments(args);
        return UnderlyingOperator::can_implement(transposed);
    }

    /// Returns the workspace size the underlying operator requires for the
    /// transposed problem.
    static size_t get_workspace_size(Arguments const& args) {
        Arguments const transposed = to_underlying_arguments(args);
        return UnderlyingOperator::get_workspace_size(transposed);
    }

    /// Returns the grid shape the underlying operator computes for the
    /// transposed problem.
    static dim3 get_grid_shape(Arguments const& args) {
        Arguments const transposed = to_underlying_arguments(args);
        return UnderlyingOperator::get_grid_shape(transposed);
    }

    /// Maximum number of active blocks per multiprocessor, as reported by the
    /// underlying operator.
    static int maximum_active_blocks(int smem_capacity = -1) {
        return UnderlyingOperator::maximum_active_blocks(smem_capacity);
    }

    /// Initializes the underlying operator's state from the transposed
    /// arguments.
    Status initialize(Arguments const& args, void* workspace = nullptr,
                      cudaStream_t stream = nullptr) {
        Arguments const transposed = to_underlying_arguments(args);
        return underlying_operator_.initialize(transposed, workspace, stream);
    }

    /// Lightweight update given a subset of arguments. The arguments are
    /// transposed before being forwarded.
    Status update(Arguments const& args, void* workspace = nullptr) {
        Arguments const transposed = to_underlying_arguments(args);
        return underlying_operator_.update(transposed, workspace);
    }

    /// Runs the kernel on `stream` using previously initialized state.
    Status run(cudaStream_t stream = nullptr) {
        return underlying_operator_.run(stream);
    }

    /// Same as run(stream).
    Status operator()(cudaStream_t stream = nullptr) { return run(stream); }

    /// Initializes from `args` and, only if initialization succeeds, runs the
    /// kernel on `stream`. Returns the first non-success status encountered.
    Status operator()(Arguments const& args, void* workspace = nullptr,
                      cudaStream_t stream = nullptr) {
        Status const init_status = initialize(args, workspace, stream);
        if (init_status != Status::kSuccess) {
            return init_status;
        }
        return run(stream);
    }

private:
    /// Row-major operator instance that receives every forwarded call.
    UnderlyingOperator underlying_operator_;
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace device
}  // namespace gemm
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
